package com.wt.demo;

import java.util.ArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

/**
 * Fork/join task that sums every integer in the inclusive range
 * {@code [start, end]}.
 *
 * <p>Ranges spanning at most {@link #THRESHOLD} are summed directly on the
 * calling thread. Wider ranges are cut into consecutive chunks of
 * {@code THRESHOLD + 1} values each (the last chunk may be shorter), every
 * chunk is executed as its own subtask, and the partial sums are added up.
 * A range with {@code start > end} is treated as empty and yields {@code 0}.
 *
 * @author wangtao
 * date: 11:34 2018/6/4
 * email: tao8.wang@changhong.com
 */
public class ForkJoinCountTask extends RecursiveTask<Long> {

    /** Largest value of {@code end - start} that is summed without splitting. */
    private static final int THRESHOLD = 10000;

    /** First value of the range (inclusive). */
    private final long start;

    /** Last value of the range (inclusive). */
    private final long end;

    /**
     * Creates a task that sums the integers from {@code start} to {@code end}.
     *
     * @param start first value to add (inclusive)
     * @param end   last value to add (inclusive)
     */
    public ForkJoinCountTask(long start, long end) {
        this.start = start;
        this.end = end;
    }

    /**
     * Computes the sum of the range, splitting it into subtasks when it is
     * wider than {@link #THRESHOLD}.
     *
     * @return the sum of all integers in {@code [start, end]}, or {@code 0}
     *         when the range is empty
     */
    @Override
    protected Long compute() {
        // Small (or empty) range: add the values up directly.
        if (end - start <= THRESHOLD) {
            long sum = 0;
            for (long i = start; i <= end; i++) {
                sum += i;
            }
            return sum;
        }

        // Large range: cut it into consecutive chunks. The chunk boundaries are
        // derived from the remaining distance to `end`, so the split is correct
        // for any start/end, including negative ranges.
        ArrayList<ForkJoinCountTask> tasks = new ArrayList<>();
        long pos = start;
        while (end - pos > THRESHOLD) {
            tasks.add(new ForkJoinCountTask(pos, pos + THRESHOLD));
            pos += THRESHOLD + 1L;
        }
        // Whatever is left (always at least one value) forms the final chunk.
        tasks.add(new ForkJoinCountTask(pos, end));

        // Run every chunk; invokeAll returns once all of them have completed.
        invokeAll(tasks);

        // Add up the partial results of the chunks.
        long sum = 0;
        for (ForkJoinCountTask t : tasks) {
            sum += t.join();
        }
        return sum;
    }

    /**
     * Demo entry point: sums 0..20000 on a fresh {@link ForkJoinPool} and
     * prints the result to standard output.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            ForkJoinTask<Long> fjt = pool.submit(new ForkJoinCountTask(0, 20000));
            long res = fjt.get();
            System.out.println(res);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        } finally {
            pool.shutdown();
        }
    }
}
